import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse


def g_CauchyK_num(S):
    """Return a sympy expression in ``z`` averaging resolvents of the symmetrized singular values.

    With ``n = len(S)`` and ``eta = sqrt(1 / (2n))`` the result is

        (1 / 2n) * sum_j [ 1/(z + s_j - i*eta) + 1/(z - s_j - i*eta) ],

    i.e. (presumably) a smoothed Cauchy transform of the empirical measure
    placing mass at both +s_j and -s_j, shifted by ``-i*eta``.

    Args:
        S: non-empty sequence of real singular values (an empty ``S`` raises
           ZeroDivisionError).

    Returns:
        A sympy expression in the symbol ``z``.
    """
    n_vals = len(S)
    z = Symbol('z')
    # Imaginary smoothing offset, shrinking like 1/sqrt(n).
    eta = np.sqrt(1/(2*n_vals))

    total = sum(
        1/(z + s - I*eta) + 1/(z - s - I*eta)
        for s in S
    )
    return total/(2*n_vals)

def Estimator(S_s, gX, gS, SNR, alpha):
    """Evaluate, per singular value, the estimators labelled "optimal eigenvalue" for X and X^2.

    For each ``s`` in ``S_s`` the evaluation point is ``w = s - i*eta`` with
    ``eta = sqrt(1/(2n))`` and ``n = len(S_s)``. Then:

        zeta   = gS(w) + ((1 - alpha)/alpha) * (1/w)
        Z      = (w/zeta - 1) / SNR
        x_hat  = Im( ((gX(sqrt Z) + gX(-sqrt Z)) / zeta) / (2*SNR*Im gS(w)) )
        xx_hat = (-1 + 1/(alpha*(Im gS(w)^2 + (Re gS(w) + (-1 + 1/alpha)/s)^2))) / SNR

    Args:
        S_s: sequence of real singular values (presumably positive; a zero
             value makes the X^2 term divide by zero).
        gX: sympy expression in the symbol ``z`` (the caller passes the
            transform of X's spectrum).
        gS: sympy expression in ``z`` (the caller passes ``g_CauchyK_num(S_s)``).
        SNR: signal-to-noise ratio, must be non-zero.
        alpha: aspect-ratio parameter, must be non-zero.

    Returns:
        ``(x_hat, xx_hat)``: two float numpy arrays of length ``len(S_s)``.
    """
    n = len(S_s)
    z = Symbol('z')
    # Same imaginary offset used when building gS in g_CauchyK_num.
    eta = np.sqrt(1/(2*n))

    est_X = np.zeros(n)
    est_XX = np.zeros(n)

    for idx, sigma in enumerate(S_s):
        point = sigma - I*eta
        g_val = gS.subs(z, point).evalf()
        g_im = im(g_val)

        # --- eigenvalue estimate for X ---
        zeta = g_val + ((1-alpha)/alpha)*(1/point)
        arg = (zeta and point/zeta - 1)/SNR if False else (point/zeta - 1)/SNR
        root = sqrt(arg)
        summed = gX.subs(z, root).evalf() + gX.subs(z, -root).evalf()
        est_X[idx] = im(((summed/zeta)/(2*SNR*g_im)).evalf())

        # --- eigenvalue estimate for X^2 ---
        shifted_re = re(g_val) + (-1 + 1/alpha)/sigma
        est_XX[idx] = (-1 + 1/(alpha*(g_im**2 + shifted_re**2)))/SNR

    return est_X, est_XX


def _conjugate(U, eigs):
    """Return ``U @ diag(eigs) @ U.T`` without building the dense diagonal matrix."""
    # Scaling column k of U by eigs[k] is exactly U @ diag(eigs).
    return (U * eigs) @ U.T


def _quadratic_forms(U, A):
    """Return the vector ``[u_k^T A u_k for each column u_k of U]``."""
    # Column-wise dot products of U with A @ U: one matmul instead of N mat-vecs.
    return np.sum(U * (A @ U), axis=0)


def _relative_error(A, A_hat):
    """Return the relative squared Frobenius error ``||A - A_hat||^2 / ||A||^2``."""
    return LA.norm(A - A_hat)**2 / LA.norm(A)**2


def _clipped_sqrt(values):
    """Return ``sqrt(values)`` elementwise, with negative (or NaN) entries set to 0."""
    values = np.asarray(values, dtype=float)
    out = np.zeros_like(values)
    keep = values >= 0
    out[keep] = np.sqrt(values[keep])
    return out


def _sample_signal(N, sparsity):
    """Draw ``X = U diag(x) U^T`` (N x N) with ``x_i`` in {0, 1} and U from a GOE matrix.

    Each ``x_i`` is 0 with probability ``sparsity`` and 1 otherwise; U is the
    eigenvector matrix of a symmetric Gaussian matrix with off-diagonal
    variance 1 and diagonal variance 2.
    """
    x = (np.random.rand(N) >= sparsity).astype(float)
    G = np.triu(np.random.normal(0, 1, (N, N)), 1)
    G = G + G.T + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=N))
    _, U = LA.eigh(G)
    return _conjugate(U, x)


def main():
    """Run the Bernoulli-spectrum denoising experiment and save the error arrays.

    Command line:
        -a        aspect ratio; M = int(N / a). Must satisfy 0 < a <= 1.
        -s        SNR; the observation is S = sqrt(SNR) * X @ Y + W.
        -p        sparsity: probability that an eigenvalue of X is 0.
        --N       matrix size (default 2000).
        --trials  number of independent trials (default 10).
        --seed    optional seed for ``np.random`` (default: unseeded).

    In each trial, X (N x N, eigenvalues in {0, 1}) is observed through
    S = sqrt(SNR) X Y + W, where Y and W are N x M Gaussian matrices scaled by
    1/sqrt(N). Using the left singular vectors U_s of S, it records relative
    squared Frobenius errors for:
      * oracle estimators of X and X @ X (eigenvalues u_k^T A u_k),
      * the ``Estimator`` RIE of X and of X @ X,
      * X rebuilt from the clipped square root of the X @ X RIE eigenvalues.
    Five ``.npy`` files are written to the working directory.
    """
    p = argparse.ArgumentParser(description='Bernoulli-spectrum denoising experiment.')
    p.add_argument('-a', type=float, required=True, help='aspect ratio N/M (0 < a <= 1)')
    p.add_argument('-s', type=float, required=True, help='signal-to-noise ratio (> 0)')
    p.add_argument('-p', type=float, required=True,
                   help='probability that an eigenvalue of X is 0 (0 <= p <= 1)')
    p.add_argument('--N', type=int, default=2000, help='matrix size N (default 2000)')
    p.add_argument('--trials', type=int, default=10, help='number of trials (default 10)')
    p.add_argument('--seed', type=int, default=None,
                   help='seed for np.random (default: unseeded)')
    args = p.parse_args()

    a = args.a
    sparsity = args.p
    SNR = args.s
    N = args.N
    trials = args.trials

    # S_s has min(N, M) entries while X is N x N, so the estimators need M >= N.
    if not 0 < a <= 1:
        p.error('-a must satisfy 0 < a <= 1 (so that M = int(N/a) >= N)')
    if not SNR > 0:
        p.error('-s must be positive')
    if not 0 <= sparsity <= 1:
        p.error('-p must lie in [0, 1]')
    if N < 1 or trials < 1:
        p.error('--N and --trials must be positive')

    if args.seed is not None:
        np.random.seed(args.seed)

    M = int(N/a)

    # Cauchy transform of X's spectrum: mass `sparsity` at 0, the rest at 1.
    # It does not depend on the trial, so build it once.
    z = Symbol('z')
    gX = sparsity*(1/z) + (1-sparsity)*(1/(z-1))

    E_X_oracle = np.zeros(trials)
    E_X_RIE = np.zeros(trials)
    E_X_sqXX = np.zeros(trials)
    E_XX_oracle = np.zeros(trials)
    E_XX_RIE = np.zeros(trials)

    for i in range(trials):
        X = _sample_signal(N, sparsity)
        XX = X @ X

        ## Noise
        Y = np.random.randn(N, M) / np.sqrt(N)
        W = np.random.randn(N, M) / np.sqrt(N)

        ### Observation
        S = np.sqrt(SNR) * X @ Y + W

        ### SVD: only U_s (N x N since M >= N) and S_s are used, so skip the M x M Vh.
        U_s, S_s, _ = LA.svd(S, full_matrices=False)
        gS = g_CauchyK_num(S_s)

        ### Oracle estimators for X and X^2
        E_X_oracle[i] = _relative_error(X, _conjugate(U_s, _quadratic_forms(U_s, X)))
        E_XX_oracle[i] = _relative_error(XX, _conjugate(U_s, _quadratic_forms(U_s, XX)))

        ### RIE for X and X^2
        e_hat_X, e_hat_XX = Estimator(S_s, gX, gS, SNR, a)
        E_X_RIE[i] = _relative_error(X, _conjugate(U_s, e_hat_X))
        # Estimate X via sqrt of the X^2 eigenvalues (every k, not only index i).
        E_X_sqXX[i] = _relative_error(X, _conjugate(U_s, _clipped_sqrt(e_hat_XX)))
        # X^2 error goes into its own array instead of overwriting E_X_RIE.
        E_XX_RIE[i] = _relative_error(XX, _conjugate(U_s, e_hat_XX))

    tag = 'Bernoulli_sparsity_' + str(sparsity) + '_SNR=' + str(SNR)
    outputs = (
        ('X-', '_Oracle.npy', E_X_oracle),
        ('XX-', '_Oracle.npy', E_XX_oracle),
        ('X-', '_RIE.npy', E_X_RIE),
        ('X-', '_sqXX.npy', E_X_sqXX),
        ('XX-', '_RIE.npy', E_XX_RIE),
    )
    for prefix, suffix, errors in outputs:
        np.save(prefix + tag + suffix, errors)

# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()
    
